{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "import glob\n",
    "import shutil\n",
    "import os\n",
    "import numpy as np\n",
    "from netCDF4 import Dataset\n",
    "from wrf import getvar, ll_to_xy, CoordPair, GeoBounds, to_np\n",
    "\n",
    "_VARS_TO_KEEP = (\"Times\", \"XLAT\", \"XLONG\", \"XLAT_U\", \"XLAT_V\", \"XLONG_U\", \n",
    "                \"XLONG_V\", \"U\", \"V\", \"W\", \"PH\", \"PHB\", \"T\", \"P\", \"PB\", \"Q2\", \n",
    "                \"T2\", \"PSFC\", \"U10\", \"V10\", \"XTIME\", \"QVAPOR\", \"QCLOUD\", \n",
    "                \"QGRAUP\", \"QRAIN\", \"QSNOW\", \"QICE\", \"MAPFAC_M\", \"MAPFAC_U\",\n",
    "                \"MAPFAC_V\", \"F\", \"HGT\", \"RAINC\", \"RAINSH\", \"RAINNC\", \"I_RAINC\", \"I_RAINNC\",\n",
    "                \"PBLH\")\n",
    "\n",
    "class FileReduce(object):\n",
    "    def __init__(self, filenames, geobounds, tempdir=None, vars_to_keep=None, \n",
    "                 max_pres=None, compress=False, delete=True, reuse=False):\n",
    "        \"\"\"An iterable object for cutting out geographic domains.\n",
    "        \n",
    "        Args:\n",
    "        \n",
    "            filenames (sequence): A sequence of file paths to the WRF files\n",
    "            \n",
    "            geobounds (GeoBounds): A GeoBounds object defining the region of interest\n",
    "            \n",
    "            tempdir (str): The location to store the temporary cropped data files. If None, tempfile.mkdtemp is used.\n",
    "            \n",
    "            vars_to_keep (sequence): A sequence of variables names to keep from the original file. None for all vars.\n",
    "            \n",
    "            max_press (float): The maximum pressure height level to keep. None for all levels.\n",
    "            \n",
    "            compress(bool): Set to True to enable zlib compression of variables in the output.\n",
    "            \n",
    "            delete (bool): Set to True to delete the temporary directory when FileReduce is garbage collected.\n",
    "            \n",
    "            reuse (bool): Set to True when you want to resuse the files that were previously converted. *tempdir* \n",
    "                must be set to a specific directory that contains the converted files and *delete* must be False.\n",
    "                \n",
    "        \n",
    "        \"\"\"\n",
    "        self._filenames = filenames\n",
    "        self._i = 0\n",
    "        self._geobounds = geobounds\n",
    "        self._delete = delete\n",
    "        self._vars_to_keep = vars_to_keep\n",
    "        self._max_pres = max_pres\n",
    "        self._compress = compress\n",
    "        self._cache = set()\n",
    "        self._own_data = True\n",
    "        self._reuse = reuse\n",
    "        \n",
    "        if tempdir is not None:\n",
    "            if not os.path.exists(tempdir):\n",
    "                os.makedirs(tempdir)\n",
    "            self._tempdir = tempdir\n",
    "            if self._reuse:\n",
    "                self._cache = set((os.path.join(self._tempdir, name) \n",
    "                                   for name in os.listdir(self._tempdir)))\n",
    "        else:\n",
    "            self._tempdir = tempfile.mkdtemp()\n",
    "\n",
    "        print (\"temporary directory is: {}\".format(self._tempdir))\n",
    "        self._prev = None\n",
    "        self._set_extents()\n",
    "    \n",
    "    def _set_extents(self):\n",
    "        fname = list(self._filenames)[0]\n",
    "        with Dataset(fname) as ncfile:\n",
    "            lons = [self._geobounds.bottom_left.lon, self._geobounds.top_right.lon]\n",
    "            lats = [self._geobounds.bottom_left.lat, self._geobounds.top_right.lat]\n",
    "            orig_west_east = len(ncfile.dimensions[\"west_east\"])\n",
    "            orig_south_north = len(ncfile.dimensions[\"south_north\"])\n",
    "            orig_bottom_top = len(ncfile.dimensions[\"bottom_top\"])\n",
    "            \n",
    "            # Note: Not handling the moving nest here\n",
    "            # Extra points included around the boundaries to ensure domain is fully included\n",
    "            x_y = ll_to_xy(ncfile, lats, lons, meta=False)\n",
    "            self._start_x = 0 if x_y[0,0] == 0 else x_y[0,0] - 1\n",
    "            self._end_x = orig_west_east - 1 if x_y[0,1] >= orig_west_east - 1 else x_y[0,1] + 1\n",
    "            self._start_y = 0 if x_y[1,0] == 0 else x_y[1,0] - 1\n",
    "            self._end_y = orig_south_north - 1 if x_y[1,1] >= orig_south_north - 1 else x_y[1,1] + 1\n",
    "            \n",
    "            self._west_east = self._end_x - self._start_x + 1\n",
    "            self._west_east_stag = self._west_east + 1\n",
    "            self._south_north = self._end_y - self._start_y + 1\n",
    "            self._south_north_stag = self._south_north + 1\n",
    "            \n",
    "            # Crop the vertical to the specified pressure\n",
    "            if self._max_pres is not None:\n",
    "                pres = getvar(ncfile, \"pressure\")\n",
    "                # Find the lowest terrain height\n",
    "                ter = to_np(getvar(ncfile, \"ter\"))\n",
    "                min_ter = float(np.amin(ter)) + 1\n",
    "                ter_less = ter <= min_ter\n",
    "                ter_less = np.broadcast_to(ter_less, pres.shape)\n",
    "                # For the lowest terrain height, find the lowest vertical index to meet \n",
    "                # the desired pressure level. The lowest terrain height will result in the \n",
    "                # largest vertical spread to find the pressure level.\n",
    "                x = np.transpose(((pres.values <= self._max_pres) & ter_less).nonzero())\n",
    "                self._end_bot_top = np.amin(x, axis=0)[0] \n",
    "                if (self._end_bot_top >= orig_bottom_top):\n",
    "                    self._end_bot_top = orig_bottom_top - 1\n",
    "            else:\n",
    "                self._end_bot_top = orig_bottom_top - 1\n",
    "                \n",
    "            self._bottom_top = self._end_bot_top + 1\n",
    "            self._bottom_top_stag = self._bottom_top + 1\n",
    "            \n",
    "            print(\"bottom_top\", self._bottom_top)\n",
    "            \n",
    "        \n",
    "    def __iter__(self):\n",
    "        return self\n",
    "    \n",
    "    def __copy__(self):\n",
    "        cp = type(self).__new__(self.__class__)\n",
    "        cp.__dict__.update(self.__dict__)\n",
    "        cp._own_data = False\n",
    "        cp._delete = False\n",
    "        \n",
    "        return cp\n",
    "    \n",
    "    def __del__(self):\n",
    "        if self._delete:\n",
    "            shutil.rmtree(self._tempdir)\n",
    "    \n",
    "    def reduce(self, fname):\n",
    "        outfilename = os.path.join(self._tempdir, os.path.basename(fname))\n",
    "        \n",
    "        # WRF-Python can iterate over sequences several times during a 'getvar', so a cache is used to \n",
    "        if outfilename in self._cache:\n",
    "            return Dataset(outfilename)\n",
    "        \n",
    "        # New dimension sizes\n",
    "        dim_d = {\"west_east\" : self._west_east,\n",
    "                 \"west_east_stag\" : self._west_east_stag,\n",
    "                 \"south_north\" : self._south_north,\n",
    "                 \"south_north_stag\" : self._south_north_stag,\n",
    "                 \"bottom_top\" : self._bottom_top,\n",
    "                 \"bottom_top_stag\" : self._bottom_top_stag\n",
    "                }\n",
    "        \n",
    "        # Data slice sizes for the 2D dimensions\n",
    "        slice_d = {\"west_east\" : slice(self._start_x, self._end_x + 1),\n",
    "                   \"west_east_stag\" : slice(self._start_x, self._end_x + 2),\n",
    "                   \"south_north\" : slice(self._start_y, self._end_y + 1),\n",
    "                   \"south_north_stag\" : slice(self._start_y, self._end_y + 2),\n",
    "                   \"bottom_top\" : slice(None, self._end_bot_top + 1),\n",
    "                   \"bottom_top_stag\" : slice(None, self._end_bot_top + 2)\n",
    "                  }\n",
    "        \n",
    "        with Dataset(fname) as infile, Dataset(outfilename, mode=\"w\") as outfile:\n",
    "            \n",
    "            # Copy the global attributes\n",
    "            outfile.setncatts(infile.__dict__)\n",
    "\n",
    "            # Copy Dimensions, limiting south_north and west_east to desired domain\n",
    "            for name, dimension in infile.dimensions.items():\n",
    "                dimsize = dim_d.get(name, len(dimension))\n",
    "                outfile.createDimension(name, dimsize)\n",
    "\n",
    "            # Copy Variables  \n",
    "            for name, variable in infile.variables.items():\n",
    "                if self._vars_to_keep is not None:\n",
    "                    if name not in self._vars_to_keep:\n",
    "                        continue\n",
    "                \n",
    "                print (name)\n",
    "                new_slices = tuple((slice_d.get(dimname, slice(None)) for dimname in variable.dimensions))\n",
    "\n",
    "                outvar = outfile.createVariable(name, variable.datatype, variable.dimensions, zlib=self._compress)\n",
    "\n",
    "                outvar[:] = variable[new_slices]\n",
    "\n",
    "                outvar.setncatts(variable.__dict__)\n",
    "                \n",
    "        \n",
    "        result = Dataset(outfilename)\n",
    "            \n",
    "        self._cache.add(outfilename)\n",
    "            \n",
    "        return result\n",
    "            \n",
    "    \n",
    "    def next(self):\n",
    "        if self._i >= len(self._filenames):\n",
    "            if self._prev is not None:\n",
    "                self._prev.close()\n",
    "            raise StopIteration\n",
    "        else:\n",
    "            fname = self._filenames[self._i]\n",
    "            reduced_file = self.reduce(fname)\n",
    "            if self._prev is not None:\n",
    "                self._prev.close()\n",
    "            self._prev = reduced_file\n",
    "            \n",
    "            self._i += 1\n",
    "            \n",
    "            return reduced_file\n",
    "    \n",
    "    # Python 3\n",
    "    def __next__(self):\n",
    "        return self.next()\n",
    "\n",
    "# How to use with getvar\n",
    "# Set lower left and upper right to your desired domain\n",
    "# Idaho bounding box: [\"41.9880561828613\",\"49.000846862793\",\"-117.243034362793\",\"-111.043563842773\"]\n",
    "ll = CoordPair(lat=41.8, lon=-117.26)\n",
    "ur = CoordPair(lat=49.1, lon=-110.5)\n",
    "bounds = GeoBounds(ll, ur)\n",
    "reduced_files = FileReduce(glob.glob(\"/Users/ladwig/Documents/wrf_files/boise_tutorial/orig/wrfout_*\"),\n",
    "                           bounds, vars_to_keep=_VARS_TO_KEEP, max_pres=400,\n",
    "                           tempdir=\"/Users/ladwig/Documents/wrf_files/boise_tutorial/reduced\", \n",
    "                           delete=False, reuse=True)\n",
    "\n",
    "pres = getvar(reduced_files, \"pressure\")\n",
    "\n",
    "print(pres)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
